from functools import partial
from tensorflow.keras.layers import Conv2D, Add, ZeroPadding2D, UpSampling2D, concatenate, MaxPooling2D, Lambda, add
from tensorflow.keras.layers import LeakyReLU,ReLU
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.regularizers import l2
import sys
sys.path.append("./")
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.applications import MobileNetV2
import tensorflow.keras.backend as K
from resnet50 import ResNet50

def losses(y_true,y_pred):
    """Compute the CenterNet-style heatmap / size / offset losses.

    Both tensors pack three heads on their last axis as
    ``[heatmap channels..., wh (2 channels), reg (2 channels)]``: everything
    except the final four channels is the per-class heatmap, the next two are
    ``wh`` and the last two are ``reg``.

    Args:
        y_true: ground-truth tensor; heatmap locations equal to exactly 1 are
            treated as object centres. Presumably (batch, H, W, classes + 4)
            -- TODO confirm against the data generator.
        y_pred: prediction tensor with the same layout. The heatmap channels
            must already be probabilities in [0, 1] (``center_branch`` applies
            a sigmoid to them).

    Returns:
        Tuple ``(loss_hm, loss_wh, loss_reg, loss_total)``. Each is summed over
        axes 1-3 and normalised by the number of centres in the sample, which
        is clamped to at least 1. For rank-4 inputs that is one value per
        sample.
    """
    # Split the packed tensors back into their three heads.
    hm = y_true[..., :-4]
    wh = y_true[..., -4:-2]
    reg = y_true[..., -2:]
    hm_output = y_pred[..., :-4]
    wh_output = y_pred[..., -4:-2]
    reg_output = y_pred[..., -2:]

    # 1.0 where the ground-truth heatmap is exactly 1 (object centres), else 0.
    mask = tf.cast(tf.equal(hm, 1), tf.float32)
    # A location counts for wh/reg if any class has a centre there; tile the
    # mask to the 2 channels of each regression head.
    mask_reg = tf.reduce_max(mask, axis=-1, keepdims=True)
    mask_reg = tf.concat([mask_reg, mask_reg], -1)

    sum_axes = [1, 2, 3]
    # Number of centres per sample. Clamp to >= 1 so a sample without objects
    # yields a finite loss instead of dividing by zero (inf / NaN).
    num_pos = tf.maximum(tf.reduce_sum(mask, axis=sum_axes), 1.0)

    # Penalty-reduced focal loss: positives weighted by (1-p)^2, negatives by
    # (1-y)^4 * p^2. The epsilon keeps log() finite at p == 0 and p == 1.
    loss_hm_pos = -1.0 * tf.pow(1. - hm_output, 2.) * tf.math.log(hm_output + 1e-12) * mask
    loss_hm_neg = (-1.0 * tf.pow(1. - hm, 4) * tf.pow(hm_output, 2)
                   * tf.math.log(1. - hm_output + 1e-12) * (1. - mask))
    loss_hm = tf.reduce_sum(loss_hm_pos + loss_hm_neg, axis=sum_axes) / num_pos

    # L1 losses, evaluated only at centre locations.
    loss_wh = tf.reduce_sum(tf.abs(wh - wh_output) * mask_reg, axis=sum_axes) / num_pos
    loss_reg = tf.reduce_sum(tf.abs(reg - reg_output) * mask_reg, axis=sum_axes) / num_pos

    loss_total = loss_hm + 1.0 * loss_wh + loss_reg
    return loss_hm, loss_wh, loss_reg, loss_total



#--------------------------------------------------#
#   GET_C: get the backbone feature maps
#--------------------------------------------------#
def GET_C(input_shape=None,classes=80):
    """Build the ResNet50 backbone and return its feature maps.

    Args:
        input_shape: forwarded unchanged as the first argument of the
            project's ``ResNet50``; defaults to ``[256, 256, 3]``. Note that
            ``build_model`` passes a Keras ``Input`` tensor here rather than a
            shape.
        classes: forwarded unchanged to ``ResNet50``.

    Returns:
        Whatever ``ResNet50`` returns; ``build_model`` unpacks it as
        ``C2, C3, C4, C5``.
    """
    # Build a fresh default per call instead of sharing one mutable list
    # across every call of this function.
    if input_shape is None:
        input_shape = [256, 256, 3]
    return ResNet50(input_shape, classes)

#--------------------------------------------------#
#   FPN
#--------------------------------------------------#
def FPN(C2,C3,C4,C5):
    """Build a top-down feature pyramid from backbone maps C2..C5.

    Each Ci is projected to 256 channels with a 1x1 conv and added to the 2x
    upsampled coarser level. Every Ci therefore must have exactly twice the
    spatial size of C(i+1) for the ``add`` to line up.

    Returns:
        ``[P2, P3, P4, P5, P6]``, all with 256 channels. P6 is a stride-2
        3x3 conv over P5, i.e. one level coarser than P5.
    """
    P5 = Conv2D(256, 1, padding='same')(C5)
    P4 = add([UpSampling2D()(P5), Conv2D(256, 1, padding='same')(C4)])
    P3 = add([UpSampling2D()(P4), Conv2D(256, 1, padding='same')(C3)])
    P2 = add([UpSampling2D()(P3), Conv2D(256, 1, padding='same')(C2)])
    # Stride 2 makes P6 a genuinely coarser level; a stride-1 conv would leave
    # it at P5's resolution and merely duplicate that level.
    P6 = Conv2D(256, 3, strides=2, padding='same')(P5)
    return [P2, P3, P4, P5, P6]


def merge_layers(big_layer,small_layer):
    """Fuse a higher-resolution feature map with a lower-resolution one.

    Fusion steps:

    1. ``small_layer`` is upsampled by the integer ratio of the two maps'
       sizes along axis 1, so it lines up with ``big_layer``.
    2. Each branch goes through its own 3x3 conv with as many filters as
       ``big_layer`` has channels.
    3. The two branches are concatenated on the last axis.
    4. The result is fused by another 3x3 conv, batch norm and LeakyReLU(0.2).

    Returns:
        A tensor with ``big_layer``'s spatial size and channel count.
    """
    ratio = int(big_layer.shape[1]) // int(small_layer.shape[1])
    channels = int(big_layer.shape[-1])

    def conv3x3(tensor):
        # 'same' padding keeps the spatial size unchanged.
        return Conv2D(channels, 3, padding='same')(tensor)

    big_branch = conv3x3(big_layer)
    small_branch = conv3x3(UpSampling2D(size=(ratio, ratio))(small_layer))
    fused = conv3x3(concatenate([big_branch, small_branch], axis=-1))
    fused = BatchNormalization()(fused)
    return LeakyReLU(0.2)(fused)
    

def center_branch(feat,n_classes=9):
    """Prediction head: three parallel branches on ``feat``.

    Each branch is a 3x3 conv (256 filters) + ReLU followed by its own 1x1
    conv:

    - hm: ``n_classes`` channels with sigmoid (per-class heatmap).
    - wh: 2 channels, linear (presumably box width/height).
    - reg: 2 channels, linear (presumably the centre offset).

    Returns:
        The three outputs concatenated on the last axis as ``[hm, wh, reg]``
        (``n_classes + 4`` channels), the layout that ``losses`` slices back
        apart.
    """
    def _head(out_channels, activation=None):
        # Each 1x1 output conv must consume its OWN 3x3+ReLU branch; the
        # wh/reg outputs previously read from `hm`, discarding their branches.
        x = Conv2D(256, 3, padding='same')(feat)
        x = ReLU()(x)
        return Conv2D(out_channels, 1, padding='same', activation=activation)(x)

    hm = _head(n_classes, 'sigmoid')
    wh = _head(2)
    reg = _head(2)
    return concatenate([hm, wh, reg], axis=-1)


#--------------------------------------------------#
#   build_model
#--------------------------------------------------#
def build_model(input_shape=None,classes=10):
    """Build the training model (backbone + aggregation + head + loss).

    Args:
        input_shape: image input shape; defaults to ``[256, 256, 3]``.
        classes: number of heatmap channels predicted by the head; also
            forwarded to the backbone builder ``GET_C``.

    Returns:
        A ``Model`` taking ``[image, label]`` and producing
        ``[prediction, hm_loss, wh_loss, reg_loss, total_loss]``. The label
        input has the same shape as the prediction (``classes + 4`` channels
        on the last axis). Only ``total_loss`` is registered with
        ``add_loss``; the component losses are model outputs for monitoring.
    """
    if input_shape is None:
        input_shape = [256, 256, 3]
    img_input = Input(shape=input_shape)

    # Backbone (ResNet50) feature maps. The merge_layers(C2, C3)-style calls
    # below treat C2 as the largest map and C5 as the smallest.
    C2, C3, C4, C5 = GET_C(img_input, classes)

    # DLA-34-style iterative aggregation.
    # stage 1
    C2_1 = merge_layers(C2, C3)
    C3_1 = merge_layers(C3, C4)
    C4_1 = merge_layers(C4, C5)

    # stage 2
    C2_2 = merge_layers(C2_1, C3_1)
    C3_2 = merge_layers(C3_1, C4_1)

    # stage 3
    # NOTE(review): merge_layers keeps its first argument's resolution, so
    # despite their names C3_3, C4_3 and (before upsampling) C5_3 all have
    # C2_3's spatial size.
    C2_3 = merge_layers(C2_2, C3_2)
    C3_3 = merge_layers(C2_3, C3_2)
    C4_3 = merge_layers(C3_3, C4_1)
    C5_3 = merge_layers(C4_3, C5)
    C5_3 = UpSampling2D()(C5_3)
    # Honour the caller's class count instead of a hard-coded value.
    result = center_branch(C5_3, n_classes=classes)
    # Ground truth is packed in the same layout as the prediction.
    label_input = Input(shape=result.shape[1:])

    # Loss layer: computes the four losses from (label, prediction).
    loss_hm, loss_wh, loss_reg, loss_total = Lambda(lambda x: losses(*x))([label_input, result])
    # Identity layers give each loss a stable name for lookup / monitoring.
    loss_hm = Lambda(lambda x: x, name='hm_loss')(loss_hm)
    loss_wh = Lambda(lambda x: x, name='wh_loss')(loss_wh)
    loss_reg = Lambda(lambda x: x, name='reg_loss')(loss_reg)
    loss_total = Lambda(lambda x: x, name='total_loss')(loss_total)

    model = Model(inputs=[img_input, label_input],
                  outputs=[result, loss_hm, loss_wh, loss_reg, loss_total])
    # total_loss already equals hm + wh + reg. Registering the components as
    # well would optimise every term twice, so only the total is added.
    model.add_loss(model.get_layer('total_loss').output)
    return model


if __name__ == "__main__":
    # Smoke check: build the default model and print its layer summary.
    build_model().summary()
    print()
